"""PPO Agent for subgraph isomorphism."""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from collections import deque
import torch.nn.functional as F

from .models import PolicyValueNetwork, GraphEncoder, StateEncoder
from .environment import SubgraphIsomorphismEnv


@dataclass
class PPOConfig:
    """Hyperparameters shared by the PPO agent and its experience buffer."""

    # --- optimizer ---
    learning_rate: float = 0.0003      # Adam step size
    # --- clipped surrogate objective ---
    clip_ratio: float = 0.2            # eps in clamp(ratio, 1 - eps, 1 + eps)
    value_coef: float = 0.5            # weight of the value-function loss term
    entropy_coef: float = 0.01         # weight of the entropy bonus (subtracted)
    max_grad_norm: float = 0.5         # threshold for global grad-norm clipping
    # --- optimization schedule ---
    ppo_epochs: int = 4                # optimization epochs per update()
    batch_size: int = 64               # minibatch size
    buffer_size: int = 2_048           # capacity of the experience buffer
    # --- return / advantage estimation ---
    gamma: float = 0.99                # discount factor
    gae_lambda: float = 0.95           # GAE lambda


class PPOBuffer:
    """Experience buffer for PPO with GAE(lambda) advantage estimation.

    Transitions are appended until ``capacity`` is reached; after that each
    new transition overwrites the oldest one (ring buffer). ``ptr`` is the
    slot the next ``store`` writes to, so once the buffer is full it also
    marks the oldest stored transition. ``advantages`` and ``returns`` are
    indexed by storage slot, exactly like every other per-transition list.
    """

    def __init__(self, capacity: int, gamma: float, gae_lambda: float):
        """
        Args:
            capacity: Maximum number of transitions kept.
            gamma: Discount factor used for the TD residuals.
            gae_lambda: GAE lambda used to decay the accumulated advantage.
        """
        self.capacity = capacity
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clear()

    def store(self, state, action, reward, value, log_prob, done):
        """Store one transition, overwriting the oldest one when full."""
        if self.size < self.capacity:
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.values.append(value)
            self.log_probs.append(log_prob)
            self.dones.append(done)
            self.size += 1
        else:
            self.states[self.ptr] = state
            self.actions[self.ptr] = action
            self.rewards[self.ptr] = reward
            self.values[self.ptr] = value
            self.log_probs[self.ptr] = log_prob
            self.dones[self.ptr] = done

        self.ptr = (self.ptr + 1) % self.capacity

    def _chronological_indices(self) -> List[int]:
        """Return storage slots ordered from oldest to newest transition."""
        if self.size == 0:
            return []
        # While not full, ptr == size, so the oldest slot is 0; once full,
        # ptr points at the oldest slot (the next one to be overwritten).
        start = self.ptr % self.size
        return [(start + offset) % self.size for offset in range(self.size)]

    def compute_advantages(self, last_value: float = 0.0):
        """Compute GAE advantages and returns for every stored transition.

        The recursion runs backwards in *time* (not storage order), so it stays
        correct after the ring buffer has wrapped.

        Args:
            last_value: Value estimate of the state that follows the newest
                stored transition; used to bootstrap when that transition is
                not terminal.
        """
        advantages = [0.0] * self.size
        advantage = 0.0
        next_value = last_value

        for idx in reversed(self._chronological_indices()):
            not_done = 1.0 - float(self.dones[idx])
            delta = self.rewards[idx] + self.gamma * next_value * not_done - self.values[idx]
            advantage = delta + self.gamma * self.gae_lambda * not_done * advantage
            advantages[idx] = advantage
            next_value = self.values[idx]

        self.advantages = advantages
        self.returns = [adv + val for adv, val in zip(advantages, self.values)]

    def get_batch(self, batch_size: int):
        """Sample up to ``batch_size`` distinct transitions uniformly at random.

        Expects ``compute_advantages`` to have been called after the last
        ``store``. If fewer than ``batch_size`` transitions are stored, all of
        them are returned (in random order).

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("Cannot sample a batch from an empty PPOBuffer")

        indices = np.random.choice(self.size, min(batch_size, self.size), replace=False)

        return {
            'states': [self.states[i] for i in indices],
            'actions': torch.tensor([self.actions[i] for i in indices], dtype=torch.long),
            'old_log_probs': torch.tensor([self.log_probs[i] for i in indices], dtype=torch.float32),
            'advantages': torch.tensor([self.advantages[i] for i in indices], dtype=torch.float32),
            'returns': torch.tensor([self.returns[i] for i in indices], dtype=torch.float32),
        }

    def clear(self):
        """Drop all stored transitions and reset the write pointer."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []
        self.advantages = []
        self.returns = []
        self.ptr = 0
        self.size = 0


class PPOAgent:
    """PPO agent for subgraph isomorphism.

    Bundles a ``PolicyValueNetwork`` built on a ``GraphEncoder``, an Adam
    optimizer over the network parameters, a ``StateEncoder`` that turns
    environment observations into network input, and a ``PPOBuffer`` of
    collected transitions.
    """

    def __init__(
        self,
        config: PPOConfig,
        node_features: int = 16,
        hidden_dim: int = 128,
        max_actions: int = 1000,
        device: str = 'cpu'
    ):
        self.config = config
        self.device = device

        # Initialize models
        # NOTE(review): the encoder is not moved to `device` on its own; this
        # presumably relies on PolicyValueNetwork registering it as a
        # submodule so `.to(device)` below covers it — confirm.
        self.encoder = GraphEncoder(
            node_features=node_features,
            hidden_dim=hidden_dim,
            num_layers=2
        )

        self.policy_value_net = PolicyValueNetwork(
            encoder=self.encoder,
            hidden_dim=hidden_dim,
            max_actions=max_actions
        ).to(device)

        self.optimizer = optim.Adam(
            self.policy_value_net.parameters(),
            lr=config.learning_rate
        )

        self.state_encoder = StateEncoder(node_feature_dim=node_features)

        # Experience buffer
        self.buffer = PPOBuffer(
            capacity=config.buffer_size,
            gamma=config.gamma,
            gae_lambda=config.gae_lambda
        )

        # Training stats: one entry per update() call.
        self.training_stats = {
            'policy_loss': [],
            'value_loss': [],
            'entropy': [],
            'total_loss': []
        }

    def select_action(self, observation: Dict, deterministic: bool = False):
        """Select an action for ``observation`` using the current policy.

        Returns:
            ``(action, log_prob, value)`` as Python scalars. ``log_prob`` is
            computed with the same ``Categorical`` construction that
            ``update`` uses, so PPO ratios start at exactly 1.
        """
        state_data = self.state_encoder.encode_state(observation).to(self.device)

        # Get valid actions
        valid_actions = self._get_valid_actions_mask(observation)

        with torch.no_grad():
            logits, value = self.policy_value_net(state_data, valid_actions)
            action_dist = torch.distributions.Categorical(F.softmax(logits, dim=-1))

            if deterministic:
                action = torch.argmax(logits, dim=-1)
            else:
                action = action_dist.sample()
            log_prob = action_dist.log_prob(action)

        return action.item(), log_prob.item(), value.item()

    def _get_valid_actions_mask(self, observation: Dict) -> torch.Tensor:
        """Return a ``(1, max_actions)`` bool mask of valid actions.

        The first ``len(observation['frontier']) + 1`` entries are True;
        slicing silently truncates if that exceeds ``max_actions``.
        """
        frontier_size = len(observation['frontier'])
        max_actions = self.policy_value_net.max_actions

        mask = torch.zeros(1, max_actions, dtype=torch.bool, device=self.device)
        # Valid frontier actions + terminate action
        mask[0, :frontier_size + 1] = True

        return mask

    def store_experience(self, state, action, reward, value, log_prob, done):
        """Store experience in buffer.

        ``state`` is expected to be the raw observation passed to
        ``select_action``; ``update`` re-encodes it and rebuilds its mask.
        """
        self.buffer.store(state, action, reward, value, log_prob, done)

    def update(self, last_value: float = 0.0) -> Dict[str, float]:
        """Run a PPO update on the collected experience, then clear the buffer.

        Each of ``config.ppo_epochs`` epochs is one shuffled pass over the
        whole buffer in minibatches of ``config.batch_size``.

        Args:
            last_value: Critic estimate for the state after the newest stored
                transition, used to bootstrap GAE when the rollout was cut off
                mid-episode. Keep ``0.0`` when that transition ended an episode.

        Returns:
            Mean policy loss, value loss, entropy and total (optimized) loss
            over all gradient steps taken.

        Raises:
            ValueError: if the buffer holds no experience.
        """
        buffer = self.buffer
        if buffer.size == 0:
            raise ValueError("Cannot update: the experience buffer is empty")

        buffer.compute_advantages(last_value)

        # Normalize advantages; the std of a single sample is NaN, so skip it.
        advantages = torch.tensor(buffer.advantages, dtype=torch.float32)
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        buffer.advantages = advantages.tolist()

        clip = self.config.clip_ratio
        batch_size = max(1, min(self.config.batch_size, buffer.size))
        totals = {'policy_loss': 0.0, 'value_loss': 0.0, 'entropy': 0.0, 'total_loss': 0.0}
        num_steps = 0

        for _ in range(self.config.ppo_epochs):
            permutation = np.random.permutation(buffer.size)
            for start in range(0, buffer.size, batch_size):
                batch = self._make_batch(permutation[start:start + batch_size])

                # Re-evaluate every sampled state with its own action mask, as
                # select_action did when the old log-probs were recorded.
                logits, values = self._evaluate_states(batch['states'])

                action_dist = torch.distributions.Categorical(F.softmax(logits, dim=-1))
                log_probs = action_dist.log_prob(batch['actions'])
                entropy = action_dist.entropy().mean()

                # Clipped surrogate objective
                ratio = torch.exp(log_probs - batch['old_log_probs'])
                surr1 = ratio * batch['advantages']
                surr2 = torch.clamp(ratio, 1 - clip, 1 + clip) * batch['advantages']
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = F.mse_loss(values, batch['returns'])

                loss = (
                    policy_loss +
                    self.config.value_coef * value_loss -
                    self.config.entropy_coef * entropy
                )

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.policy_value_net.parameters(),
                    self.config.max_grad_norm
                )
                self.optimizer.step()

                totals['policy_loss'] += policy_loss.item()
                totals['value_loss'] += value_loss.item()
                totals['entropy'] += entropy.item()
                totals['total_loss'] += loss.item()
                num_steps += 1

        buffer.clear()

        # Guard against ppo_epochs == 0 (no gradient steps taken).
        stats = {key: total / max(num_steps, 1) for key, total in totals.items()}
        for key, value in stats.items():
            self.training_stats.setdefault(key, []).append(value)

        return stats

    def _make_batch(self, indices) -> Dict:
        """Gather the buffer transitions at ``indices`` as tensors on ``self.device``."""
        buffer = self.buffer

        def gather(column, dtype):
            return torch.tensor([column[i] for i in indices], dtype=dtype, device=self.device)

        return {
            'states': [buffer.states[i] for i in indices],
            'actions': gather(buffer.actions, torch.long),
            'old_log_probs': gather(buffer.log_probs, torch.float32),
            'advantages': gather(buffer.advantages, torch.float32),
            'returns': gather(buffer.returns, torch.float32),
        }

    def _evaluate_states(self, states: List) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on each state with its valid-action mask.

        Returns:
            ``(logits, values)``: logits stacked with one row per state, and a
            1-D tensor holding one value per state.
        """
        logits_rows, value_rows = [], []
        for state, state_data in zip(states, self._batch_states(states)):
            mask = self._get_valid_actions_mask(state)
            logits, value = self.policy_value_net(state_data, mask)
            # Single-state forward yields one logits row and one value
            # (select_action relies on the same shapes via [0, a] and .item()).
            logits_rows.append(logits.reshape(1, -1))
            value_rows.append(value.reshape(-1))
        return torch.cat(logits_rows, dim=0), torch.cat(value_rows, dim=0)

    def _batch_states(self, states: List) -> List:
        """Encode each stored observation and move it to ``self.device``.

        States are kept as a list instead of being collated into one graph
        batch, because the network is only used with single-state input here.
        A PyTorch Geometric ``Batch`` could replace this if the network
        supports batched graphs.
        """
        return [self.state_encoder.encode_state(state).to(self.device) for state in states]

    def save(self, filepath: str):
        """Save network weights, optimizer state and training stats."""
        torch.save({
            'policy_value_net': self.policy_value_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'training_stats': self.training_stats
        }, filepath)

    def load(self, filepath: str):
        """Load a checkpoint written by ``save`` onto ``self.device``."""
        checkpoint = torch.load(filepath, map_location=self.device)
        self.policy_value_net.load_state_dict(checkpoint['policy_value_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.training_stats = checkpoint['training_stats']